{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Regression - Reaction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/training_regression_reaction.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from lightning import pytorch as pl\n",
    "from pathlib import Path\n",
    "\n",
    "from chemprop import data, featurizers, models, nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change data inputs here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"rxn\" / \"rxn.csv\"\n",
    "num_workers = 0  # number of workers for dataloader. 0 means using main process for data loading\n",
    "smiles_column = 'smiles'\n",
    "target_columns = ['ea']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>ea</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1...</td>\n",
       "      <td>8.898934</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:...</td>\n",
       "      <td>5.464328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...</td>\n",
       "      <td>5.270552</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])...</td>\n",
       "      <td>8.473006</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H...</td>\n",
       "      <td>5.579037</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>[C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]...</td>\n",
       "      <td>9.295665</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>[O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11...</td>\n",
       "      <td>7.753442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...</td>\n",
       "      <td>10.650215</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>[C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4...</td>\n",
       "      <td>10.138945</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>[C:1]([C@@:2]1([C:3]([C:4]([O:5][H:15])([H:13]...</td>\n",
       "      <td>6.979934</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               smiles         ea\n",
       "0   [O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1...   8.898934\n",
       "1   [C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:...   5.464328\n",
       "2   [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...   5.270552\n",
       "3   [C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])...   8.473006\n",
       "4   [C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H...   5.579037\n",
       "..                                                ...        ...\n",
       "95  [C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]...   9.295665\n",
       "96  [O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11...   7.753442\n",
       "97  [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...  10.650215\n",
       "98  [C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4...  10.138945\n",
       "99  [C:1]([C@@:2]1([C:3]([C:4]([O:5][H:15])([H:13]...   6.979934\n",
       "\n",
       "[100 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_input = pd.read_csv(input_path)\n",
    "df_input"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load smiles and targets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array(['[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:15])([H:13])[H:14])([H:11])[H:12])([H:9])[H:10])[H:8]>>[C:3](=[C:4]=[O:5])([H:11])[H:12].[C:6]([O:7][H:15])([H:8])([H:13])[H:14].[O:1]=[C:2]([H:9])[H:10]',\n",
       "        '[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:4]3([H:11])[O:5][C@:6]1([H:12])[C@@:7]23[H:13]>>[C:1]1([H:8])([H:9])[O:2][C:3]([H:10])=[C:7]([H:13])[C@:6]1([O+:5]=[C-:4][H:11])[H:12]',\n",
       "        '[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H:13])([H:14])[C:5]([H:15])=[C:6]([H:16])[C@@:7]12[H:17])([H:8])([H:9])[H:10]>>[C:1]([C@@:2]1([H:11])[C:3]([H:12])([H:13])[C:4]([H:14])=[C:5]([H:15])[C:6]([H:16])=[C:7]1[H:17])([H:8])([H:9])[H:10]',\n",
       "        '[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])([H:11])[H:12])([H:8])([H:9])[H:10]>>[C-:1]([O+:2]=[C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])[H:12])([H:8])[H:10].[H:9][H:11]',\n",
       "        '[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H:10])[H:11])([H:7])([H:8])[H:9]>>[C:1]([C:2](=[C:3]=[C:4]([H:10])[H:11])[C:5](=[O:6])[H:12])([H:7])([H:8])[H:9]'],\n",
       "       dtype=object),\n",
       " array([[8.8989335 ],\n",
       "        [5.46432769],\n",
       "        [5.27055228],\n",
       "        [8.47300569],\n",
       "        [5.57903696]]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "smis = df_input.loc[:, smiles_column].values\n",
    "ys = df_input.loc[:, target_columns].values\n",
    "\n",
    "smis[:5], ys[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get datapoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_data = [data.ReactionDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perform data splitting for training, validation, and testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols = [d.rct for d in all_data]  # Can either split by reactants (.rct) or products (.pdt)\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    all_data, train_indices, val_indices, test_indices\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Defining the featurizer\n",
    "\n",
    "Reactions can be featurized using the ```CondensedGraphOfReactionFeaturizer``` (also labeled ```CGRFeaturizer```).\n",
    "\n",
    "\n",
    "Use ```_mode``` keyword to set the mode by which a reaction should be featurized into a ```MolGraph```.\n",
    "\n",
    "Options are can be found with ```featurizers.RxnMode.keys```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "REAC_PROD\n",
      "REAC_PROD_BALANCE\n",
      "REAC_DIFF\n",
      "REAC_DIFF_BALANCE\n",
      "PROD_DIFF\n",
      "PROD_DIFF_BALANCE\n"
     ]
    }
   ],
   "source": [
    "for key in featurizers.RxnMode.keys():\n",
    "    print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "featurizer = featurizers.CondensedGraphOfReactionFeaturizer(mode_=\"PROD_DIFF\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get ReactionDatasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dset = data.ReactionDataset(train_data[0], featurizer)\n",
    "scaler = train_dset.normalize_targets()\n",
    "\n",
    "val_dset = data.ReactionDataset(val_data[0], featurizer)\n",
    "val_dset.normalize_targets(scaler)\n",
    "test_dset = data.ReactionDataset(test_data[0], featurizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get dataloaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = data.build_dataloader(train_dset, num_workers=num_workers)\n",
    "val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Change Message-Passing Neural Network (MPNN) inputs here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Message passing\n",
    "\n",
    "Message passing blocks must be given the shape of the featurizer's outputs.\n",
    "\n",
    "Options are `mp = nn.BondMessagePassing()` or `mp = nn.AtomMessagePassing()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "fdims = featurizer.shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).\n",
    "mp = nn.BondMessagePassing(*fdims)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassRegistry {\n",
      "    'mean': <class 'chemprop.nn.agg.MeanAggregation'>,\n",
      "    'sum': <class 'chemprop.nn.agg.SumAggregation'>,\n",
      "    'norm': <class 'chemprop.nn.agg.NormAggregation'>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(nn.agg.AggregationRegistry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "agg = nn.MeanAggregation()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Feed-Forward Network (FFN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassRegistry {\n",
      "    'regression': <class 'chemprop.nn.predictors.RegressionFFN'>,\n",
      "    'regression-mve': <class 'chemprop.nn.predictors.MveFFN'>,\n",
      "    'regression-evidential': <class 'chemprop.nn.predictors.EvidentialFFN'>,\n",
      "    'regression-quantile': <class 'chemprop.nn.predictors.QuantileFFN'>,\n",
      "    'classification': <class 'chemprop.nn.predictors.BinaryClassificationFFN'>,\n",
      "    'classification-dirichlet': <class 'chemprop.nn.predictors.BinaryDirichletFFN'>,\n",
      "    'multiclass': <class 'chemprop.nn.predictors.MulticlassClassificationFFN'>,\n",
      "    'multiclass-dirichlet': <class 'chemprop.nn.predictors.MulticlassDirichletFFN'>,\n",
      "    'spectral': <class 'chemprop.nn.predictors.SpectralFFN'>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(nn.PredictorRegistry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "ffn = nn.RegressionFFN(output_transform=output_transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batch norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_norm = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ClassRegistry {\n",
      "    'mse': <class 'chemprop.nn.metrics.MSE'>,\n",
      "    'mae': <class 'chemprop.nn.metrics.MAE'>,\n",
      "    'rmse': <class 'chemprop.nn.metrics.RMSE'>,\n",
      "    'bounded-mse': <class 'chemprop.nn.metrics.BoundedMSE'>,\n",
      "    'bounded-mae': <class 'chemprop.nn.metrics.BoundedMAE'>,\n",
      "    'bounded-rmse': <class 'chemprop.nn.metrics.BoundedRMSE'>,\n",
      "    'r2': <class 'chemprop.nn.metrics.R2Score'>,\n",
      "    'binary-mcc': <class 'chemprop.nn.metrics.BinaryMCCMetric'>,\n",
      "    'multiclass-mcc': <class 'chemprop.nn.metrics.MulticlassMCCMetric'>,\n",
      "    'roc': <class 'chemprop.nn.metrics.BinaryAUROC'>,\n",
      "    'prc': <class 'chemprop.nn.metrics.BinaryAUPRC'>,\n",
      "    'accuracy': <class 'chemprop.nn.metrics.BinaryAccuracy'>,\n",
      "    'f1': <class 'chemprop.nn.metrics.BinaryF1Score'>\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(nn.metrics.MetricRegistry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric_list = [nn.metrics.RMSE(), nn.metrics.MAE()] \n",
    "# Only the first metric is used for training and early stopping"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct MPNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MPNN(\n",
       "  (message_passing): BondMessagePassing(\n",
       "    (W_i): Linear(in_features=134, out_features=300, bias=False)\n",
       "    (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
       "    (W_o): Linear(in_features=406, out_features=300, bias=True)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "    (tau): ReLU()\n",
       "    (V_d_transform): Identity()\n",
       "    (graph_transform): Identity()\n",
       "  )\n",
       "  (agg): MeanAggregation()\n",
       "  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (predictor): RegressionFFN(\n",
       "    (ffn): MLP(\n",
       "      (0): Sequential(\n",
       "        (0): Linear(in_features=300, out_features=300, bias=True)\n",
       "      )\n",
       "      (1): Sequential(\n",
       "        (0): ReLU()\n",
       "        (1): Dropout(p=0.0, inplace=False)\n",
       "        (2): Linear(in_features=300, out_features=1, bias=True)\n",
       "      )\n",
       "    )\n",
       "    (criterion): MSE(task_weights=[[1.0]])\n",
       "    (output_transform): UnscaleTransform()\n",
       "  )\n",
       "  (X_d_transform): Identity()\n",
       "  (metrics): ModuleList(\n",
       "    (0): RMSE(task_weights=[[1.0]])\n",
       "    (1): MAE(task_weights=[[1.0]])\n",
       "    (2): MSE(task_weights=[[1.0]])\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)\n",
    "mpnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and testing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    logger=False,\n",
    "    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.\n",
    "    enable_progress_bar=True,\n",
    "    accelerator=\"auto\",\n",
    "    devices=1,\n",
    "    max_epochs=20,  # number of epochs to train for\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.\n",
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 252 K  | train\n",
      "1 | agg             | MeanAggregation    | 0      | train\n",
      "2 | bn              | BatchNorm1d        | 600    | train\n",
      "3 | predictor       | RegressionFFN      | 90.6 K | train\n",
      "4 | X_d_transform   | Identity           | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "343 K     Trainable params\n",
      "0         Non-trainable params\n",
      "343 K     Total params\n",
      "1.374     Total estimated model params size (MB)\n",
      "25        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0:   0%|          | 0/2 [00:00<?, ?it/s]                             "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  7.13it/s, train_loss_step=0.0888, val_loss=0.937, train_loss_epoch=0.0274]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19: 100%|██████████| 2/2 [00:00<00:00,  6.38it/s, train_loss_step=0.0888, val_loss=0.937, train_loss_epoch=0.0274]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(mpnn, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 59.59it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mae          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     1.111189842224121     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/rmse         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.4387098550796509     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mae         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    1.111189842224121    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/rmse        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.4387098550796509    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = trainer.test(mpnn, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
